"""
DMLC submission script, SLURM version
"""

# pylint: disable=invalid-name
from __future__ import absolute_import

import logging
import subprocess
from threading import Thread

from . import tracker


def get_mpi_env(envs):
    """Format ``envs`` as the environment prefix of the slurm command.

    Every entry of the mapping is rendered as ``KEY=value`` followed by one
    space, in the mapping's iteration order. A non-empty result therefore
    ends with a trailing space, and an empty mapping yields ``""``. Values
    are converted with ``str`` and are not quoted or escaped.
    """
    assignments = ["%s=%s " % (name, str(value)) for name, value in envs.items()]
    return "".join(assignments)


def submit(args):
    """Submission script with SLURM.

    Hands a job-launching closure to ``tracker.submit`` together with the
    requested number of workers and servers.

    Parameters
    ----------
    args : object
        Parsed options. The attributes read here are ``command`` (sequence of
        command-line tokens, joined with spaces), ``num_workers``,
        ``num_servers``, ``slurm_worker_nodes`` and ``slurm_server_nodes``
        (node counts for ``srun -N``; ``None`` means "use as many nodes as
        tasks").
    """

    def mpi_submit(nworker, nserver, pass_envs):
        """Internal closure for job submission.

        Starts up to two background ``srun`` invocations, one for the workers
        and one for the servers, and returns without waiting for them.

        Note that ``pass_envs`` is modified in place: ``DMLC_JOB_CLUSTER`` is
        set to ``"slurm"`` and ``DMLC_ROLE`` is set to the role launched last.
        """

        def run(prog):
            """Run the shell command line ``prog``.

            ``check_call`` raises ``CalledProcessError`` on a non-zero exit
            status; since this runs inside a background thread, that error is
            not propagated to the caller of ``mpi_submit``.
            """
            subprocess.check_call(prog, shell=True)

        def launch(role, ntask, nnodes):
            """Start ``ntask`` processes of ``role`` via srun in a daemon thread."""
            if ntask <= 0:
                return
            if nnodes is None:
                # No explicit node count given: request one node per task.
                nnodes = ntask

            logging.info("Start %d %ss by srun", ntask, role)
            pass_envs["DMLC_ROLE"] = role
            # The command line is rendered now, so later changes to
            # ``pass_envs`` (e.g. the next role) do not affect this launch.
            prog = "%s srun --share --exclusive=user -N %d -n %d %s" % (
                get_mpi_env(pass_envs),
                nnodes,
                ntask,
                cmd,
            )
            thread = Thread(target=run, args=(prog,))
            # ``Thread.setDaemon`` is deprecated; set the attribute instead.
            # Daemon threads do not keep the interpreter alive at exit.
            thread.daemon = True
            thread.start()

        cmd = " ".join(args.command)

        pass_envs["DMLC_JOB_CLUSTER"] = "slurm"

        # Workers are started first, then servers, each with its own srun.
        launch("worker", nworker, args.slurm_worker_nodes)
        launch("server", nserver, args.slurm_server_nodes)

    tracker.submit(
        args.num_workers,
        args.num_servers,
        fun_submit=mpi_submit,
        pscmd=(" ".join(args.command)),
    )
